import torch
import torch.nn as nn


class _Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``.

    Implemented as a module-level ``nn.Module`` rather than a lambda so that a
    model holding it can be pickled (and therefore saved whole with
    ``torch.save``) and so that it appears in ``repr(model)``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.tanh(nn.functional.softplus(x))


class ZhnNetFpn(nn.Module):
    """Channel-reducing head: 1x1 convolution -> batch norm -> Mish.

    The 1x1 convolution only mixes channels, so the spatial size of the input
    is preserved; only the channel dimension changes from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels: int = 768, out_channels: int = 96) -> None:
        super().__init__()
        # bias=False: the following BatchNorm2d has its own learnable shift,
        # which makes a convolution bias redundant.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        # NOTE: the attribute keeps its historical name ``relu`` so existing
        # code that accesses ``model.relu`` keeps working, but the activation
        # is Mish, not ReLU. It holds no parameters, so state_dict keys are
        # unaffected.
        self.relu = _Mish()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch norm -> activation.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``
                (``N x 768 x 4 x 5`` in the original usage).

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``
            (``N x 96 x 4 x 5`` in the original usage).
        """
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)
